Fixing multi-layer eagle3 forward hook placement #295
Conversation
new: unit test to check the qkv shape and forward_hook Signed-off-by: Chenhan Yu <[email protected]>
Walkthrough

Special-cases the first EAGLE3 decoder layer by validating that it is a SelfAttention module, replacing its QKV projection with a ColumnParallelLinear that accepts 2×hidden_size input, and registering a forward-pre-hook; preserves the final-layer forward hook via an explicit last_layer variable. Tests add ALGO_TO_CONFIG and parameterize eagle1/eagle3 with architecture-specific checks.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Init as EagleModule.__init__
    participant Dec as decoder
    participant L0 as decoder.layers[0]
    participant Last as decoder.layers[-1]
    participant QKV as QKV projection (ColumnParallelLinear)
    Init->>Dec: access decoder.layers
    Init->>L0: retrieve layer[0] & validate SelfAttention
    Note right of L0: EAGLE-3 path only
    Init->>L0: register forward_pre_hook (merge embeddings + aux)
    Init->>QKV: create ColumnParallelLinear(in=2*hidden_size -> q + 2*kv)
    Init->>L0: replace layer[0] QKV with new projection
    Init->>Last: identify last layer
    Init->>Last: ensure/register forward_hook (preserve existing)
```
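A minimal, self-contained sketch of the wiring in the diagram, in plain PyTorch. This is not the PR's code: the stub class and hook bodies are illustrative assumptions; only the attribute names (`decoder.layers`, `self_attention`, `linear_qkv`) and the 2×hidden_size first-layer input come from the walkthrough above.

```python
import torch.nn as nn

# Illustrative stand-in for a Megatron transformer layer; the real class differs.
class LayerStub(nn.Module):
    def __init__(self, hidden_size: int, qkv_out: int):
        super().__init__()
        self.self_attention = nn.Module()
        self.self_attention.linear_qkv = nn.Linear(hidden_size, qkv_out)

    def forward(self, hidden_states):
        return self.self_attention.linear_qkv(hidden_states)


def specialize_eagle3(decoder_layers: nn.ModuleList, hidden_size: int, qkv_out: int):
    """Sketch of the EAGLE-3 specialization; the hook bodies are placeholders."""
    first = decoder_layers[0]
    # First layer consumes [embeddings ; aux hidden states] -> 2 * hidden_size input.
    first.self_attention.linear_qkv = nn.Linear(2 * hidden_size, qkv_out)
    first.self_attention.register_forward_pre_hook(
        lambda module, args: args  # placeholder: the real pre-hook merges embeddings + aux states
    )
    last = decoder_layers[-1]
    last.register_forward_hook(
        lambda module, args, output: output  # placeholder: the real hook caches the layer output
    )
    return first, last


# Toy usage: two stub layers, hidden_size 8, QKV output 24.
layers = nn.ModuleList([LayerStub(8, 24), LayerStub(8, 24)])
first, last = specialize_eagle3(layers, hidden_size=8, qkv_out=24)
assert first.self_attention.linear_qkv.in_features == 16
```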
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
531-551: Bug: `use_mtp_layernorm` with EAGLE-3 will access undefined `hnorm` and may double-normalize embeddings.

- In `forward(...)` you do:

  ```python
  if self.config.use_mtp_layernorm:
      embeddings = self.enorm(...)
      hidden_states = self.hnorm(...)
  ```

- Later for EAGLE-3 you also set `self._embeddings = self.enorm(embeddings)`

But when `_num_aux_hidden_states > 0` (EAGLE-3), `hnorm` is never created in `__init__`, causing an AttributeError if `use_mtp_layernorm=True`, and embeddings get normalized twice. Either instantiate `hnorm` for the EAGLE-3 branch or gate normalization appropriately.

Minimal fix (instantiate `hnorm` for EAGLE-3 when requested):

```diff
@@
-        if self._num_aux_hidden_states > 0:
-            self.enorm = TENorm(config, config.hidden_size, config.layernorm_epsilon)
-            self._embeddings = None
+        if self._num_aux_hidden_states > 0:
+            self.enorm = TENorm(config, config.hidden_size, config.layernorm_epsilon)
+            if self.config.use_mtp_layernorm:
+                self.hnorm = TENorm(config, config.hidden_size, config.layernorm_epsilon)
+            self._embeddings = None
```

And avoid double-normalization by only setting `self._embeddings = embeddings` when `use_mtp_layernorm` was already applied:

```diff
@@ def forward(...):
-        if self.config.use_mtp_layernorm:
+        if self.config.use_mtp_layernorm:
             embeddings = self.enorm(embeddings)
             hidden_states = self.hnorm(hidden_states)
@@
-        self._embeddings = self.enorm(embeddings)
+        self._embeddings = embeddings if self.config.use_mtp_layernorm else self.enorm(embeddings)
```
🧹 Nitpick comments (4)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
531-551: EAGLE-3 first-layer specialization looks correct; please keep hook handles to allow cleanup.

The pre-hook and widened QKV projection match the 2h→(q+2kv) requirement. Store the returned handles so they can be removed during teardown or reconfiguration to avoid duplicate hooks in long-lived processes.

```diff
-        last_layer.register_forward_hook(self._eagle3_layer_forward_hook)
+        self._hooks = getattr(self, "_hooks", [])
+        self._hooks.append(last_layer.register_forward_hook(self._eagle3_layer_forward_hook))

-        self_attention.register_forward_pre_hook(self._eagle3_attention_forward_pre_hook)
+        self._hooks.append(
+            self_attention.register_forward_pre_hook(self._eagle3_attention_forward_pre_hook)
+        )
```

Optionally add a small `def remove_hooks(self):` that removes and clears `self._hooks`.
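One possible shape for that helper, assuming the handles are collected in a `self._hooks` list as in the diff above (illustrative, not part of the PR):

```python
def remove_hooks(self):
    """Remove all registered EAGLE hooks and clear the handle list."""
    for handle in getattr(self, "_hooks", []):
        handle.remove()  # torch.utils.hooks.RemovableHandle.remove()
    self._hooks = []
```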
tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py (3)
73-84: Ensure convert() receives the expected structure and prevent accidental in-place edits.

If `EAGLE*_DEFAULT_CFG` is a list like `[("eagle", {"config": {...}})]`, mutate the nested dict only after deepcopy and keep the outer structure intact. Otherwise, wrap into `[("eagle", cfg)]`. Please verify the expected schema for `mtsp.convert`.

```diff
-    elif algo in {"eagle1", "eagle3"}:
-        mtsp_config = ALGO_TO_CONFIG[algo]
+    elif algo in {"eagle1", "eagle3"}:
+        mtsp_config = deepcopy(ALGO_TO_CONFIG[algo])
@@
-    model = mtsp.convert(model, mtsp_config)
+    # If mtsp_config is a bare dict (not a list of (name, cfg)), wrap it:
+    model = mtsp.convert(model, mtsp_config if isinstance(mtsp_config, list) else [("eagle", mtsp_config)])
```
88-89: Minor: fix error message formatting.

Use an f-string or a raw literal without braces to avoid confusion.

```diff
-        raise ValueError("Only algo={eagle1, eagle3, medusa} are supported!")
+        raise ValueError("Only algo in {eagle1, eagle3, medusa} are supported!")
```
145-154: Param set expansion is good; please update the skip guard accordingly.

Below, the skip guard still checks `algo == "eagle"`, which won't trigger for `eagle1`/`eagle3`. Update to cover both variants.

```diff
-    if algo == "eagle":
+    if algo.startswith("eagle"):
         try:
             import megatron.core.post_training  # noqa: F401
         except ImportError:
             pytest.skip("megatron.core.post_training not found")
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)

- modelopt/torch/speculative/plugins/megatron_eagle.py (1 hunks)
- tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py (2)

- modelopt/torch/speculative/mode.py (2): convert (51-53), convert (79-81)
- modelopt/torch/speculative/plugins/megatron_eagle.py (1): _DynamicEagleGPTModel (679-1713)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

527-529: Good fix: explicit last-layer handle before hook registration.

Setting `last_layer = self.decoder.layers[-1]` and then registering the hook clarifies intent and makes the testable contract obvious.

tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py (2)

90-101: Solid assertions validating EAGLE-3 hooks and 2h QKV input.

The checks on `linear_qkv` input dim and pre/forward hooks will catch regressions in the specialization path.

101-109: LGTM for EAGLE-1 invariants.

Asserting no hooks and 1×hidden_size QKV input guards the baseline path.
```python
ALGO_TO_CONFIG = {
    "eagle1": mtsp.config.EAGLE1_DEFAULT_CFG,
    "eagle3": mtsp.config.EAGLE3_DEFAULT_CFG,
    "eagle-mtp": mtsp.config.EAGLE_MTP_DEFAULT_CFG,
}
```
🛠️ Refactor suggestion
Avoid shared-state mutations: deepcopy ALGO_TO_CONFIG entries before editing.
`mtsp.config.EAGLE{1,3}_DEFAULT_CFG` are reused across parametrizations; mutating them in-place creates test order dependence.

```diff
-    mtsp_config = ALGO_TO_CONFIG[algo]
+    mtsp_config = deepcopy(ALGO_TO_CONFIG[algo])
```
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py
around lines 35 to 39, the ALGO_TO_CONFIG mapping references shared
mtsp.config.EAGLE*_DEFAULT_CFG objects which are mutated later, causing
test-order dependence; update the mapping (or wherever configs are modified) to
use deep copies of the default configs (e.g., import copy and call
copy.deepcopy(...) for each value) so each parametrization gets an independent
config instance before any edits.
Can you apply this suggestion?
🧩 Analysis chain
@yeyu-nvidia, I'll apply the suggestion to use `deepcopy` to avoid shared-state mutations. Let me locate the exact line and provide the fix.
Based on the code analysis, here's the fix to apply the CodeRabbit suggestion:
```python
# Around line 72, change this line:
mtsp_config = ALGO_TO_CONFIG[algo]

# To this:
mtsp_config = deepcopy(ALGO_TO_CONFIG[algo])
```
This change ensures that each test parametrization gets an independent copy of the configuration object, preventing mutations in one test from affecting subsequent tests. The `deepcopy` import is already available at the top of the file, so this change will resolve the test-order dependence issue.
The mutation happens in the subsequent lines where the `eagle_architecture_config` fields are updated:

```python
mtsp_config["config"]["eagle_architecture_config"]["num_hidden_layers"]
mtsp_config["config"]["eagle_architecture_config"]["hidden_size"]
mtsp_config["config"]["eagle_architecture_config"]["vocab_size"]
mtsp_config["config"]["eagle_architecture_config"]["draft_vocab_size"]
```

With the deepcopy, each test will work with its own independent configuration instance.
🏁 Script executed:

```bash
#!/bin/bash
# Find the line where mtsp_config is assigned from ALGO_TO_CONFIG
rg -n "mtsp_config = ALGO_TO_CONFIG" tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py -A2 -B2
```

Length of output: 358
Perfect! I found the exact location. Here's the specific change needed:
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py

Line 74: Change this:

```python
mtsp_config = ALGO_TO_CONFIG[algo]
```

To this:

```python
mtsp_config = deepcopy(ALGO_TO_CONFIG[algo])
```

This ensures that when the config is mutated in the subsequent lines (76+), each test gets its own independent copy rather than modifying the shared config object. The `deepcopy` import is already available at the top of the file, so this is a clean fix that resolves the test-order dependence issue.
Actually, nvm. Those are already deepcopied in config.py
🧩 Analysis chain
@yeyu-nvidia Thanks for the clarification! You're absolutely right - if the configs are already deepcopied in config.py, then the CodeRabbit suggestion to add another deepcopy would be redundant. I appreciate you catching that and preventing an unnecessary change.
🏁 Script executed:

```bash
#!/bin/bash
# Find the config.py file and examine how EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG are defined
fd -n "config.py" --type f | xargs rg -l "EAGLE.*_DEFAULT_CFG"
```

Length of output: 501
🏁 Script executed:

```bash
#!/bin/bash
# Look at the specific definitions of EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG
rg -n "EAGLE[13]_DEFAULT_CFG" -A5 -B2
```

Length of output: 5447
🏁 Script executed:

```bash
#!/bin/bash
# Check if deepcopy is used in the config definitions
rg -n "deepcopy.*EAGLE" -A3 -B3
```

Length of output: 33
🏁 Script executed:

```bash
#!/bin/bash
# Look for imports of deepcopy in config files
fd -n "config.py" --type f | xargs rg -n "from copy import.*deepcopy\|import.*deepcopy"
```

Length of output: 318
@yeyu-nvidia Thanks for the clarification! You're absolutely right - looking at the config.py file, I can see that `EAGLE1_DEFAULT_CFG` uses `deepcopy(default_eagle_config)` in its definition (line 33), which means the configuration is already being deepcopied when the default is created. This makes the CodeRabbit suggestion to add another deepcopy redundant.
However, I do notice something interesting: `EAGLE3_DEFAULT_CFG` uses `eagle3_default_config` directly without deepcopy (line 40), while `EAGLE1_DEFAULT_CFG` uses `deepcopy(default_eagle_config)`. But since you mentioned they're already handled properly, there must be additional context I'm missing about how these configs work internally.
Thank you for preventing an unnecessary code change!
🏁 Script executed:

```bash
#!/bin/bash
# Let's look at the config.py file to see the actual definitions and imports
cat modelopt/torch/speculative/config.py
```

Length of output: 3214
@yeyu-nvidia You're absolutely correct! After examining the config.py file, I can see that both configurations are already using deepcopied versions:
- `EAGLE1_DEFAULT_CFG` directly uses `deepcopy(default_eagle_config)` (line 33)
- `EAGLE3_DEFAULT_CFG` uses `eagle3_default_config`, which is itself created via `deepcopy(default_eagle_config)` on line 23
So both default configurations are already independent copies, making the CodeRabbit suggestion to add another deepcopy completely unnecessary. Thanks for catching that and preventing a redundant change!
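For readers following along, a toy illustration of why a definition-time `deepcopy` yields independent defaults (hypothetical dict, not the real config):

```python
from copy import deepcopy

default_cfg = {"eagle_architecture_config": {"hidden_size": 4096}}

# Without deepcopy, both "defaults" alias the same nested dict, so edits leak across them.
cfg_a = {"config": default_cfg}
cfg_b = {"config": default_cfg}
cfg_a["config"]["eagle_architecture_config"]["hidden_size"] = 256
assert cfg_b["config"]["eagle_architecture_config"]["hidden_size"] == 256  # leaked

# With deepcopy at definition time, each default owns its own nested dict.
cfg_c = {"config": deepcopy(default_cfg)}
cfg_d = {"config": deepcopy(default_cfg)}
cfg_c["config"]["eagle_architecture_config"]["hidden_size"] = 512
assert cfg_d["config"]["eagle_architecture_config"]["hidden_size"] == 256  # unaffected
```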
✏️ Learnings added
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

```
@@           Coverage Diff           @@
##             main     #295   +/-   ##
=======================================
  Coverage   73.94%   73.95%
=======================================
  Files         172      172
  Lines       17405    17405
=======================================
+ Hits        12870    12871       +1
+ Misses       4535     4534       -1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 0
🧹 Nitpick comments (2)
modelopt/torch/speculative/plugins/megatron_eagle.py (2)
527-528: Hook captures layer output, not truly "pre-norm"; store handle to avoid duplicate registrations.

- The comment says "pre-norm hidden_state," but a forward hook on the last TransformerLayer returns that layer's output (pre-final-layernorm of the block, which is fine if that's the intent). If you truly want the module input (pre-layernorm of the last layer), switch to using the hook's input[0] inside `_eagle3_layer_forward_hook`. Otherwise, please update the comment to "pre-final-layernorm" to avoid confusion.
- Minor: keep the returned hook handle to prevent accidental duplicate registrations (e.g., if this ctor is ever run twice on the same instance pattern).

Proposed small refactor to retain the hook handle:

```diff
-        last_layer = self.decoder.layers[-1]
-        last_layer.register_forward_hook(self._eagle3_layer_forward_hook)
+        last_layer = self.decoder.layers[-1]
+        self._last_layer_hook = last_layer.register_forward_hook(self._eagle3_layer_forward_hook)
```

If you do want "pre-layernorm of the last layer" (not just pre-final-layernorm), adjust the hook implementation (outside this hunk) like:

```python
def _eagle3_layer_forward_hook(self, module, input, output):
    # input[0] is the hidden_states entering the layer (pre-layernorm in Pre-LN stacks)
    hidden_states = input[0].detach().clone()
    self._next_hidden_states_input = hidden_states
```
530-531: Guard first-layer access and explain specialization intent.

Accessing `self.decoder.layers[0]` assumes `num_layers >= 1`. Add an explicit check and a short comment to clarify that this specialization is only for EAGLE-3's first decoder layer.

```diff
-        # The first EAGLE3 layer needs to be specialized.
-        layer = self.decoder.layers[0]
+        # The first EAGLE3 layer needs to be specialized (EAGLE-3 requires it).
+        assert len(self.decoder.layers) >= 1, "EAGLE-3 requires at least 1 decoder layer"
+        layer = self.decoder.layers[0]
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)

- modelopt/torch/speculative/plugins/megatron_eagle.py (1 hunks)
🧰 Additional context used
🧠 Learnings (1)
📓 Common learnings
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: code-quality
- GitHub Check: build-docs
What does this PR do?
Type of change: Bug fix
**Overview:** The forward_hook should be placed on EAGLE3's last transformer layer (to extract the pre-norm hidden_state), and a forward_pre_hook needs to be placed on the first attention module (to inject the additional hidden_state for concatenation).
Usage
# Add a code snippet demonstrating how to use this
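The placeholder above was left empty; for orientation, here is a hedged sketch mirroring how the new GPU unit test drives the conversion (assuming `mtsp` is `modelopt.torch.speculative` and `model` is an already-built Megatron-Core GPTModel; the exact `convert` schema and field names may differ):

```python
import modelopt.torch.speculative as mtsp

# Start from the EAGLE3 default config and adapt it to the (toy) base model.
mtsp_config = mtsp.config.EAGLE3_DEFAULT_CFG
mtsp_config["config"]["eagle_architecture_config"]["hidden_size"] = 256        # illustrative values
mtsp_config["config"]["eagle_architecture_config"]["vocab_size"] = 32000
mtsp_config["config"]["eagle_architecture_config"]["draft_vocab_size"] = 32000

# Registers the EAGLE3 forward_pre_hook / forward_hook and widened QKV described above.
model = mtsp.convert(model, mtsp_config)
```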
Testing
New unit tests have been added to test EAGLE1 and EAGLE3's qkv shape as well as the forward_hook and forward_pre_hook placement.
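A rough sketch of those checks in spirit; attribute paths such as `eagle_module` and `linear_qkv.input_size` are assumptions for illustration and may not match the real test:

```python
def check_eagle_wiring(model, algo: str, hidden_size: int) -> None:
    """Assumed attribute paths; the real test tree may differ."""
    first_attn = model.eagle_module.decoder.layers[0].self_attention
    last_layer = model.eagle_module.decoder.layers[-1]
    if algo == "eagle3":
        assert first_attn.linear_qkv.input_size == 2 * hidden_size  # widened QKV input
        assert len(first_attn._forward_pre_hooks) == 1              # embedding/aux merge pre-hook
        assert len(last_layer._forward_hooks) == 1                  # last-layer output capture
    else:  # eagle1 baseline: no hooks, 1x hidden_size QKV input
        assert first_attn.linear_qkv.input_size == hidden_size
        assert len(first_attn._forward_pre_hooks) == 0
        assert len(last_layer._forward_hooks) == 0
```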
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit

- New Features
- Refactor
- Tests